from typing import Any, Callable

from error import *
from symbol_table import SymbolTable
from tokens import *
from ast_node import *
from context import *
from result import RTResult, ParserResult
from type_operate import *
from function import *
from type_operate import String, List


class Interpreter(object):
    """Tree-walking interpreter: evaluates AST nodes into runtime values wrapped in RTResult."""

    def visit(self, node, context: Context) -> RTResult | None | Any:
        """
        Recursive-descent entry point: dispatch ``node`` to its visitXxx handler.
        :param node: AST node to evaluate
        :param context: current execution context (owns the symbol table)
        :return: the RTResult produced by the matching handler
        :raises Exception: (from no_visit_method) when the node type has no handler
        """
        if isinstance(node, NumberNode):
            return self.visitNumberNode(node, context)
        elif isinstance(node, BinOpNode):
            return self.visitBinOpNode(node, context)
        elif isinstance(node, UnaryOpNode):
            return self.visitUnaryOpNode(node, context)
        elif isinstance(node, VarAccessNode):
            return self.visitVarAccessNode(node, context)
        elif isinstance(node, VarAssignNode):
            return self.visitVarAssignNode(node, context)
        elif isinstance(node, IfNode):
            return self.visitIfNode(node, context)
        elif isinstance(node, ForNode):
            return self.visitForNode(node, context)
        elif isinstance(node, WhileNode):
            return self.visitWhileNode(node, context)
        elif isinstance(node, FuncNode):
            return self.visitFuncNode(node, context)
        elif isinstance(node, CallNode):
            return self.visitCallNode(node, context)
        elif isinstance(node, StringNode):
            return self.visitStringNode(node, context)
        elif isinstance(node, ListNode):
            return self.visitListNode(node, context)
        elif isinstance(node, ReturnNode):
            return self.visitReturnNode(node, context)
        elif isinstance(node, BreakNode):
            return self.visitBreakNode(node, context)
        elif isinstance(node, ContinueNode):
            return self.visitContinueNode(node, context)
        else:
            return self.no_visit_method(node, context)

    def no_visit_method(self, node, context: Context):
        """Fallback for AST node types that have no visit handler: always raises."""
        raise Exception(f'No visit_{type(node).__name__}')

    def visitIfNode(self, node: IfNode, context: Context) -> RTResult:
        """
        Evaluate an if / elif / else chain.

        ``node.case`` is a sequence of (condition, body, is_multipy_row) triples;
        the body of the first condition that is true is evaluated.
        ``node.else_case``, when present, is a (body, is_multipy_row) pair.
        A multi-row (block) body yields Number.null; a single-expression body
        yields its own value.
        :return: the selected branch's value, or Number.null when no branch runs
        """
        res = RTResult()
        for condition, expr, is_multipy_row in node.case:
            condition_value = res.register(self.visit(condition, context))
            if res.should_return():
                return res

            if condition_value.is_true():
                expr_value = res.register(self.visit(expr, context))
                if res.should_return():
                    return res
                return res.success(Number.null if is_multipy_row else expr_value)

        if node.else_case:
            else_expr, is_multipy_row = node.else_case
            else_value = res.register(self.visit(else_expr, context))
            if res.should_return():
                return res
            # Treat a multi-row else body the same way as a multi-row if/elif body.
            return res.success(Number.null if is_multipy_row else else_value)

        # No branch matched. Return a real runtime value, not a raw Python list,
        # so callers (assignment, binary operators) can operate on the result.
        return res.success(Number.null)

    def visitVarAssignNode(self, node: VarAssignNode, context: Context) -> RTResult:
        """
        Evaluate the right-hand side and bind it to the variable name in the
        current symbol table.
        :return: a positioned copy of the assigned value
        """
        res = RTResult()
        var_name = node.var_name_token.value
        value = res.register(self.visit(node.value_node, context))
        if res.should_return():
            return res
        context.symbol_table.set(var_name, value)
        # Return a copy so later position changes do not touch the stored value.
        if value is not None:
            value = value.copy().set_pos(node.pos_start, node.pos_end)
        return res.success(value)

    def visitBinOpNode(self, node: BinOpNode, context) -> RTResult:
        """
        Evaluate a binary operation.

        Both operands are evaluated left to right. The operation is then
        delegated to the left operand's corresponding method
        (added_by, get_comparison_ee, and_by, ...).
        :return: the operation result positioned at this node, or a failure
                 carrying the operand's error or a RunTimeError for an
                 unsupported operator
        """
        res = RTResult()
        left = res.register(self.visit(node.left_node, context))
        if res.should_return():
            return res
        right = res.register(self.visit(node.right_node, context))
        if res.should_return():
            return res

        error = None
        result = None
        if node.op_token.type == TT_PLUS:  # addition
            result, error = left.added_by(right)
        elif node.op_token.type == TT_MINUS:  # subtraction
            result, error = left.subbed_by(right)
        elif node.op_token.type == TT_MUL:  # multiplication
            result, error = left.multed_by(right)
        elif node.op_token.type == TT_DIV:  # division
            result, error = left.dived_by(right)
        elif node.op_token.type == TT_POW:  # exponentiation
            result, error = left.powed_by(right)
        elif node.op_token.type == TT_EE:  # equal
            result, error = left.get_comparison_ee(right)
        elif node.op_token.type == TT_NE:  # not equal
            result, error = left.get_comparison_ne(right)
        elif node.op_token.type == TT_LT:  # less than
            result, error = left.get_comparison_lt(right)
        elif node.op_token.type == TT_GT:  # greater than
            result, error = left.get_comparison_gt(right)
        elif node.op_token.type == TT_LTE:  # less than or equal
            result, error = left.get_comparison_lte(right)
        elif node.op_token.type == TT_GTE:  # greater than or equal
            result, error = left.get_comparison_gte(right)
        elif node.op_token.matches(TT_KEYWORDS, 'and'):  # logical and
            result, error = left.and_by(right)
        elif node.op_token.matches(TT_KEYWORDS, 'or'):  # logical or
            result, error = left.or_by(right)
        elif node.op_token.matches(TT_KEYWORDS, 'not'):  # logical not
            result, error = left.not_by(right)
        else:
            return res.failure(RunTimeError(
                node.pos_start, node.pos_end, f"'{node.op_token.type}' 操作不被支持"
                , context
            ))

        if error:
            return res.failure(error)
        return res.success(result.set_pos(node.pos_start, node.pos_end))

    def visitUnaryOpNode(self, node: UnaryOpNode, context) -> RTResult:
        """
        Evaluate a unary operation: ``-x`` (computed as x * -1) or ``not x``.
        Any other operator leaves the operand unchanged.
        :return: the result positioned at this node, or a failure
        """
        res = RTResult()
        number = res.register(self.visit(node.node, context))
        if res.should_return():
            return res

        error = None
        if node.op_token.type == TT_MINUS:  # negation
            number, error = number.multed_by(Number(-1))
        elif node.op_token.matches(TT_KEYWORDS, 'not'):
            number, error = number.not_by()

        if error:
            return res.failure(error)
        return res.success(number.set_pos(node.pos_start, node.pos_end))

    def visitNumberNode(self, node: NumberNode, context) -> RTResult:
        """Wrap a numeric literal token in a Number value."""
        return RTResult().success(
            Number(node.token.value).set_context(context).set_pos(node.pos_start, node.pos_end)
        )

    def visitStringNode(self, node: StringNode, context) -> RTResult:
        """Wrap a string literal token in a String value."""
        return RTResult().success(
            String(node.token.value).set_context(context).set_pos(node.pos_start, node.pos_end)
        )

    def visitVarAccessNode(self, node: VarAccessNode, context: Context) -> RTResult:
        """
        Look up a variable in the current symbol table.
        :return: a positioned copy of the stored value, or a RunTimeError
                 failure when the name is not defined
        """
        res = RTResult()
        var_name = node.var_name_token.value
        value = context.symbol_table.get(var_name)
        # Test for None explicitly so a defined-but-falsy value is not
        # reported as undefined.
        if value is None:
            return res.failure(RunTimeError(
                node.pos_start, node.pos_end,
                f"变量 '{var_name}' 没有被定义", context
            ))
        # Copy so position changes do not mutate the value held in the table.
        value = value.copy().set_pos(node.pos_start, node.pos_end)
        return res.success(value)

    def visitForNode(self, node: ForNode, context: Context) -> RTResult:
        """
        Evaluate a counted for loop.

        The loop variable is rebound in ``context.symbol_table`` before each
        iteration. The end bound is exclusive: the loop runs while ``i < end``
        for a non-negative step, or while ``i > end`` for a negative step.
        The step defaults to 1.
        :return: Number.null when node.should_return_null, otherwise a List
                 whose first element is a label string followed by each
                 iteration's body value
        """
        res = RTResult()
        case_values = ['for 循环值 : ']
        start_value = res.register(self.visit(node.start_value_node, context))
        if res.should_return():
            return res

        end_value = res.register(self.visit(node.end_value_node, context))
        if res.should_return():
            return res

        step_value = Number(1)  # default: advance by one per iteration
        if node.step_value_node:
            step_value = res.register(self.visit(node.step_value_node, context))
            if res.should_return():
                return res

        i = start_value.value
        end = end_value.value
        step = step_value.value
        ascending = step >= 0

        while (i < end) if ascending else (i > end):
            context.symbol_table.set(node.var_name_token.value, Number(i))
            i += step
            # Run the body once.
            v = res.register(self.visit(node.body_node, context))
            # Propagate errors / function returns; break and continue are handled here.
            if res.should_return() and res.loop_should_break is False and res.loop_should_continue is False:
                return res
            if res.loop_should_continue:
                continue
            if res.loop_should_break:
                break
            case_values.append(v)

        if node.should_return_null:
            # Plain success (not success_return): finishing a loop must not
            # make the enclosing function return.
            return res.success(Number.null)
        return res.success(List(case_values).set_pos(node.pos_start, node.pos_end).set_context(context))

    def visitWhileNode(self, node: WhileNode, context: Context) -> RTResult:
        """
        Evaluate a while loop: re-check the condition before each iteration.
        :return: Number.null when node.should_return_null, otherwise a List
                 whose first element is a label string followed by each
                 iteration's body value
        """
        res = RTResult()
        case_values = ['while 循环值 : ']
        while True:
            condition: Number = res.register(self.visit(node.condition_node, context))
            if res.should_return():
                return res

            if not condition.is_true():
                break

            v = res.register(self.visit(node.body_node, context))
            # Propagate errors / function returns; break and continue are handled here.
            if res.should_return() and res.loop_should_break is False and res.loop_should_continue is False:
                return res
            if res.loop_should_continue:
                continue
            if res.loop_should_break:
                break
            case_values.append(v)
        if node.should_return_null:
            return res.success(Number.null)
        return res.success(List(case_values).set_pos(node.pos_start, node.pos_end).set_context(context))

    def visitFuncNode(self, node: FuncNode, context: Context) -> RTResult:
        """
        Build a Function value from a function definition.

        Named functions are also bound in the current symbol table.
        Anonymous functions get ``None`` as their name.
        :return: the Function value
        """
        res = RTResult()

        func_name = node.var_name_token.value if node.var_name_token else None
        body_node = node.body_node
        arg_names = [arg_name.value for arg_name in node.arg_name_tokens]

        func_value = Function(func_name, body_node, arg_names, node.should_auto_return).set_context(context).set_pos(
            node.pos_start,
            node.pos_end)

        if node.var_name_token:
            context.symbol_table.set(func_name, func_value)

        return res.success(func_value)

    def visitCallNode(self, node: CallNode, context: Context) -> RTResult:
        """
        Evaluate a call expression.

        The callee expression (typically a VarAccessNode that yields the
        Function object) is evaluated and copied. Each argument is evaluated
        left to right, then ``execute(args, self)`` is invoked on the callee.
        :return: the callee's result, or the first failure encountered
        """
        res = RTResult()
        args = []
        value_to_call: BaseFunction = res.register(self.visit(node.node_to_call, context))
        if res.should_return():
            return res

        value_to_call = value_to_call.copy().set_pos(node.pos_start, node.pos_end).set_context(context)

        for arg_node in node.arg_nodes:
            # Each argument is an evaluated runtime Value.
            args.append(res.register(self.visit(arg_node, context)))
            if res.should_return():
                return res
        return_value = res.register(value_to_call.execute(args, self))
        if res.should_return():
            return res
        return res.success(return_value)

    def visitListNode(self, node: ListNode, context: Context) -> RTResult:
        """Evaluate each element expression in order and wrap the results in a List."""
        res = RTResult()
        elements = []
        for en in node.element_nodes:
            elements.append(res.register(self.visit(en, context)))
            if res.should_return():
                return res
        return res.success(List(elements).set_context(context).set_pos(node.pos_start, node.pos_end))

    def visitReturnNode(self, node: ReturnNode, context) -> RTResult:
        """
        Evaluate a ``return`` statement.
        :param node: ReturnNode; node_to_return may be empty for a bare return
        :param context: current execution context
        :return: an RTResult marked via success_return with the returned value
                 (Number.null for a bare return)
        """
        res = RTResult()

        if node.node_to_return:
            value = res.register(self.visit(node.node_to_return, context))
            # Propagate an error raised while evaluating the return expression.
            if res.should_return():
                return res
        else:
            value = Number.null
        return res.success_return(value)

    def visitContinueNode(self, node, context) -> RTResult:
        """
        Evaluate a ``continue`` statement.

        Returns RTResult.success_continue(); the enclosing loop checks
        ``loop_should_continue`` and skips to its next iteration.
        :param node: ContinueNode (unused)
        :param context: current execution context (unused)
        :return: a continue-signalling RTResult
        """
        return RTResult().success_continue()

    def visitBreakNode(self, node, context) -> RTResult:
        """
        Evaluate a ``break`` statement.

        Returns RTResult.success_break(); the enclosing loop checks
        ``loop_should_break`` and exits.
        :param node: BreakNode (unused)
        :param context: current execution context (unused)
        :return: a break-signalling RTResult
        """
        return RTResult().success_break()
